import os
import json
import math
import h5py
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib as mpl
from torch.utils.data import TensorDataset, DataLoader
from gru_model_standalone import GRUModel, BasicGRUModel, GRUModelLayerNormChrono
import time 


# Utility: build test sequences from HDF5 

def build_test_sequences(h5_path, timesteps):
    """
    Build next-step prediction windows from every group of an HDF5 file.

    For each top-level group a window of length ``timesteps`` is slid over
    the data with stride 1. The input window starts at index ``j`` and the
    target window is the same span shifted one step ahead (``j+1``), so a
    group with ``len(target) = T`` contributes ``T - timesteps`` windows.
    Groups no longer than ``timesteps`` contribute nothing.

    HDF5 layout is assumed like:
        group[...] with datasets: 'scaled_parameters', 'normalized_current'
    where 'scaled_parameters' is assumed to be (T, features) and
    'normalized_current' is assumed to be (T,).

    Args:
        h5_path: Path to the HDF5 file to read.
        timesteps: Window length in samples.

    Returns:
        Tuple ``(X, y)`` of float32 arrays: the stacked input windows and
        the stacked, one-step-shifted target windows.

    Raises:
        ValueError: If no window could be built (empty file, or every group
            is too short for the requested ``timesteps``).
    """
    input_windows, target_windows = [], []
    with h5py.File(h5_path, 'r') as f:
        for key in f:
            grp = f[key]
            feats = grp['scaled_parameters'][:]      # assumed shape (T, features)
            targ = grp['normalized_current'][:]      # assumed shape (T,)
            # The target is shifted by one step, hence one window fewer
            # than a plain sliding window would give.
            for j in range(len(targ) - timesteps):
                input_windows.append(feats[j:j + timesteps])
                target_windows.append(targ[j + 1:j + 1 + timesteps])

    if not input_windows:
        # np.stack([]) would fail with an opaque "need at least one array"
        # message; report the actual cause instead (same exception type).
        raise ValueError(
            f"No test windows could be built from {h5_path!r}: the file has no "
            f"group with more than timesteps={timesteps} samples."
        )

    return (
        np.stack(input_windows).astype(np.float32),
        np.stack(target_windows).astype(np.float32),
    )


# Core evaluation

def _compute_nmae_percent(preds, truths, eps=1e-12):
    """Per-window MAE normalized by 2 * max amplitude of (truth, pred), in percent.

    Both inputs are 2-D arrays with one window per row. ``eps`` is the lower
    bound applied to the amplitude so all-zero windows do not divide by zero.
    """
    peak = np.maximum(np.max(np.abs(truths), axis=1),
                      np.max(np.abs(preds), axis=1))
    peak = np.clip(peak, eps, None)
    mae_win = np.mean(np.abs(preds - truths), axis=1)
    return mae_win / (2.0 * peak) * 100.0


def _build_model(model_type, is_chrono_ckpt, input_size, hidden_size, num_layers, dropout):
    """Instantiate the (untrained) model class that matches the training run."""
    if model_type == "layernorm":
        if is_chrono_ckpt:
            print("[Eval] Detected chrono checkpoint (gate_bias found). Using GRUModelLayerNormChrono.")
            model_cls = GRUModelLayerNormChrono
        else:
            model_cls = GRUModel
    elif model_type == "basic":
        model_cls = BasicGRUModel
    else:
        raise ValueError(f"Unknown model_type in hparams: {model_type}")

    return model_cls(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        dropout=dropout,
    )


def _save_sorted_curve(values, label, xlabel, ylabel, title, out_path):
    """Save a line plot of an already-sorted 1-D metric array to ``out_path``."""
    fig = plt.figure(figsize=(8, 4))
    try:
        plt.plot(values, label=label)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_path, dpi=300)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def _plot_window(time_axis, truth, pred, nmae_percent, out_path):
    """Save a single-column (8.5 cm) truth-vs-prediction figure for one window."""
    cm = 1 / 2.54  # centimetres -> inches
    rc_overrides = {
        "font.size": 10,
        # Type-42 (TrueType) fonts are embedded in PDF/PS output
        # (avoids blurry text in LaTeX / editors).
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
    # Scope the rc overrides to this figure instead of mutating the global
    # matplotlib configuration of the calling process.
    with mpl.rc_context(rc_overrides):
        fig = plt.figure(figsize=(8.5 * cm, 8.5 * cm))
        try:
            plt.plot(time_axis, truth, label="Ground truth", color='gray', linewidth=3, linestyle=':')
            plt.plot(time_axis, pred, label="Prediction", color="#4ED1A5", linewidth=1.5, linestyle="-")

            plt.title(
                f"nMAE={nmae_percent:.3f}%",
                fontsize=10,
                fontfamily="Nimbus Roman"
            )
            plt.xlabel("Time [s]")
            plt.xticks(fontsize=10)
            plt.legend(fontsize=10)

            ax = plt.gca()
            ax.set_ylim(-1.5, 1.5)
            ax.set_yticks([])  # amplitudes are normalized; y ticks carry no information
            ax.set_xticks(np.arange(0, time_axis[-1] + 1e-9, 5))

            # Thin outer frame.
            for spine in ax.spines.values():
                spine.set_linewidth(0.5)

            plt.tight_layout()
            plt.savefig(out_path)
        finally:
            plt.close(fig)


def evaluate_run(
    run_dir,
    test_h5_path,
    which_ckpt="best_by_nrmse",   # or "best_by_mse" or "last"
    batch_size=256,
    num_plot_windows=10,
    dt=0.1,
):
    """
    Evaluate a trained GRU checkpoint on a test HDF5 file and save metrics and plots.

    Reads ``hparams.json`` and ``<which_ckpt>.pth`` from ``run_dir``, rebuilds
    the matching model, runs it over all test windows and writes the results
    into ``run_dir/eval``:

        metrics.txt, nmae_new_per_window.npy, sorted_nmae_new.npy,
        sorted_mae_plot.png, sorted_nmae_plot.png,
        worst_by_nmae.pdf, best_by_nmae.pdf, sample_<idx>.png

    The test windows are evaluated in dataset order (no shuffling), so the
    per-window arrays and the seeded sample plots refer to the same windows
    on every run and can be compared between checkpoints.

    Args:
        run_dir: Training run directory containing hparams.json and *.pth.
        test_h5_path: Path to the test HDF5 file (see build_test_sequences).
        which_ckpt: One of "best_by_nrmse", "best_by_mse", "last".
        batch_size: Batch size used for inference.
        num_plot_windows: Number of randomly chosen (seed 0) windows to plot;
            capped at the number of test windows, negative values count as 0.
        dt: Sample period in seconds used for the time axis of window plots.

    Returns:
        dict with keys "mse", "mae", "nrmse", "nmae_mean" (floats) and
        "eval_dir" (output directory path).

    Raises:
        ValueError: If ``which_ckpt`` or the ``model_type`` in hparams is unknown.
        FileNotFoundError: If hparams.json or the checkpoint file is missing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"[Eval] Device: {device}")

    # Cheap argument check first, so a typo fails before any data is loaded.
    if which_ckpt not in ["best_by_nrmse", "best_by_mse", "last"]:
        raise ValueError("which_ckpt must be one of: best_by_nrmse, best_by_mse, last")

    # ---- 1) Load hparams from training run ----
    hparams_path = os.path.join(run_dir, "hparams.json")
    if not os.path.exists(hparams_path):
        raise FileNotFoundError(f"hparams.json not found in run_dir: {hparams_path}")

    with open(hparams_path, "r", encoding="utf-8") as f:
        hp = json.load(f)

    model_type  = hp.get("model_type", "layernorm")  # default if older runs miss this field
    hidden_size = hp["hidden_size"]
    num_layers  = hp["num_layers"]
    dropout     = hp["dropout"]
    timesteps   = hp.get("timesteps", 160)

    print(f"[Eval] Loaded hparams from {hparams_path}")
    print(f"        model_type={model_type}, hidden_size={hidden_size}, "
          f"num_layers={num_layers}, dropout={dropout}, timesteps={timesteps}")

    # Fail fast on a missing checkpoint before building the test set.
    ckpt_path = os.path.join(run_dir, f"{which_ckpt}.pth")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # ---- 2) Build test dataset ----
    X_test, y_test = build_test_sequences(test_h5_path, timesteps)

    print(f"[Eval] Test sequences: X={X_test.shape}, y={y_test.shape}")

    test_ds = TensorDataset(torch.from_numpy(X_test), torch.from_numpy(y_test))
    # shuffle=False: keeps predictions aligned with dataset order, so the saved
    # per-window arrays and the seeded sample indices are reproducible.
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    # NRMSE denominator = std of test targets
    nrmse_denom = float(y_test.reshape(-1).std())
    if nrmse_denom < 1e-8:
        nrmse_denom = 1.0  # (near-)constant targets: fall back to plain RMSE
    print(f"[Eval] NRMSE denominator (std of test targets): {nrmse_denom:.6f}")

    # ---- 3) Load checkpoint (loaded before the model is built to detect chrono) ----
    print(f"[Eval] Loading checkpoint: {ckpt_path}")
    # NOTE: torch.load unpickles the file; only evaluate checkpoints from trusted runs.
    ckpt = torch.load(ckpt_path, map_location=device)

    # accept either full ckpt dict or raw state dict
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt

    # Detect chrono model from checkpoint keys
    is_chrono_ckpt = any(k.endswith("gate_bias") for k in state_dict.keys())

    # ---- 4) Build the matching architecture and load weights ----
    model = _build_model(
        model_type,
        is_chrono_ckpt,
        input_size=X_test.shape[2],
        hidden_size=hidden_size,
        num_layers=num_layers,
        dropout=dropout,
    ).to(device)
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    # ---- 5) Evaluate ----
    mse_crit = nn.MSELoss()
    mae_crit = nn.L1Loss()

    total_mse = 0.0
    total_mae = 0.0
    sse_test  = 0.0
    n_test    = 0
    n_windows = 0

    all_preds  = []
    all_truths = []

    with torch.no_grad():
        for Xb, yb in test_loader:
            Xb = Xb.to(device)
            yb = yb.to(device)

            out = model(Xb)  # assumed (B, T), matching yb
            batch_n = Xb.size(0)

            # Criteria return batch means; weight by batch size so the final
            # average is correct when the last batch is smaller.
            total_mse += mse_crit(out, yb).item() * batch_n
            total_mae += mae_crit(out, yb).item() * batch_n

            err = out - yb
            sse_test += (err * err).sum().item()
            n_test   += err.numel()

            n_windows += batch_n

            all_preds.append(out.cpu().numpy())
            all_truths.append(yb.cpu().numpy())

    avg_mse = total_mse / max(n_windows, 1)
    avg_mae = total_mae / max(n_windows, 1)
    rmse_test = math.sqrt(max(sse_test / max(n_test, 1), 0.0) + 1e-12)
    nrmse_test = rmse_test / max(nrmse_denom, 1e-12)

    print(f"[Eval] Test MSE  : {avg_mse:.10f}")
    print(f"[Eval] Test MAE  : {avg_mae:.10f}")
    print(f"[Eval] Test NRMSE: {nrmse_test:.10f}")

    preds  = np.vstack(all_preds)
    truths = np.vstack(all_truths)

    # ---- 6) Per-window metrics ----
    nmae_new_per_window = _compute_nmae_percent(preds, truths)
    nmae_new_mean = nmae_new_per_window.mean()

    print(f"[Eval] New nMAE (mean over windows, normalized): {nmae_new_mean:.3f} %")
    print(f"[Eval] New nMAE range: {nmae_new_per_window.min():.3f} % – {nmae_new_per_window.max():.3f} %")

    mae_per_window = np.abs(truths - preds).mean(axis=1)
    sorted_mae = np.sort(mae_per_window)
    sorted_nmae_new = np.sort(nmae_new_per_window)

    # ---- 7) Save metrics and plots into run_dir/eval ----
    eval_dir = os.path.join(run_dir, "eval")
    os.makedirs(eval_dir, exist_ok=True)

    metrics_text = (
        f"Test MSE: {avg_mse:.10f}\n"
        f"Test MAE: {avg_mae:.10f}\n"
        f"Test NRMSE: {nrmse_test:.10f}\n"
        f"New nMAE (mean over windows, normalized): {nmae_new_mean:.3f} %\n"
    )
    with open(os.path.join(eval_dir, "metrics.txt"), "w", encoding="utf-8") as f:
        f.write(metrics_text)

    # save nMAE for later comparison
    np.save(os.path.join(eval_dir, "nmae_new_per_window.npy"), nmae_new_per_window)
    np.save(os.path.join(eval_dir, "sorted_nmae_new.npy"), sorted_nmae_new)

    _save_sorted_curve(
        sorted_mae,
        label="Sorted MAE",
        xlabel="Sample Index (sorted by MAE)",
        ylabel="MAE",
        title="MAE per Test Sample (Sorted)",
        out_path=os.path.join(eval_dir, "sorted_mae_plot.png"),
    )
    _save_sorted_curve(
        sorted_nmae_new,
        label="Sorted nMAE (%)",
        xlabel="Sample Index (sorted by nMAE)",
        ylabel="nMAE (%)",
        title="nMAE per Test Sample (Sorted)",
        out_path=os.path.join(eval_dir, "sorted_nmae_plot.png"),
    )

    # Best / worst windows by nMAE
    time_axis = np.arange(timesteps) * dt

    def plot_window(idx, filename):
        _plot_window(
            time_axis,
            truths[idx],
            preds[idx],
            nmae_new_per_window[idx],
            os.path.join(eval_dir, filename),
        )

    plot_window(int(np.argmax(nmae_new_per_window)), "worst_by_nmae.pdf")
    plot_window(int(np.argmin(nmae_new_per_window)), "best_by_nmae.pdf")

    # Random sample windows for inspection (fixed seed -> same windows every run)
    rng = np.random.default_rng(0)
    n_total = len(preds)
    k = max(0, min(num_plot_windows, n_total))
    for idx in rng.choice(n_total, size=k, replace=False):
        plot_window(idx, f"sample_{idx}.png")

    print(f"[Eval] Saved metrics and plots to {eval_dir}")

    return {
        "mse": avg_mse,
        "mae": avg_mae,
        "nrmse": nrmse_test,
        "nmae_mean": float(nmae_new_mean),
        "eval_dir": eval_dir,
    }


# CLI wrapper

if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them,
    # unchanged, to evaluate_run().
    import argparse

    parser = argparse.ArgumentParser()

    # Required locations of the training run and of the test data.
    parser.add_argument(
        "--run_dir",
        type=str,
        required=True,
        help="Path to training run directory (where hparams.json and *.pth live)",
    )
    parser.add_argument(
        "--test_data",
        type=str,
        required=True,
        help="Path to test HDF5 file",
    )

    # Optional evaluation settings; defaults mirror evaluate_run().
    parser.add_argument(
        "--which_ckpt",
        type=str,
        default="best_by_nrmse",
        choices=["best_by_nrmse", "best_by_mse", "last"],
        help="Which checkpoint file in run_dir to evaluate",
    )
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--num_plot_windows", type=int, default=10)

    cli = parser.parse_args()

    evaluate_run(
        run_dir=cli.run_dir,
        test_h5_path=cli.test_data,
        which_ckpt=cli.which_ckpt,
        batch_size=cli.batch_size,
        num_plot_windows=cli.num_plot_windows,
    )
